% fcn_steps_solver_endolabor.m
% 
% Solves the structural model with endogenous labor
% 
% "The Past and Future of U.S. Structural Change" 
% Andrew Foerster, Andreas Hornstein, Pierre-Daniel Sarte, Mark Watson
% September 2025
% % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % % 
function [equilibrium,retcode] = fcn_steps_solver_endolabor(varphi,rho,disp_code,py_lambda,z,psi,model)
% FCN_STEPS_SOLVER_ENDOLABOR  Tatonnement solver for the model with endogenous labor.
%
%   [equilibrium,retcode] = fcn_steps_solver_endolabor(varphi,rho,disp_code,py_lambda,z,psi,model)
%
%   Iterates on the 6x1 vector py_lambda = [py; lambda_c]. Each pass maps
%   the current guess through STEPS 1-21 below to an excess-demand vector
%   (five goods-market gaps cd - cs plus one equation in lambda_c) and
%   updates the guess as py_lambda + lambda.*excess_demand, floored at
%   1e-10. The loop stops when norm(excess_demand) <= crit (1e-5) or after
%   max_iter (200000) passes.
%
%   Inputs
%     varphi     - enters the labor step, l = (lambda_c*w./varphi).^(1/gamma_l)
%     rho        - enters the consumption block (STEPS 16 and 19) and the
%                  sixth excess-demand equation
%     disp_code  - if equal to 1, the excess-demand norm is displayed on
%                  every iteration
%     py_lambda  - initial guess; elements 1:5 are used as py and element 6
%                  as lambda_c
%     z          - multiplies the value-added function (STEPS 5, 6, 10)
%     psi        - enters the resource constraint through psi.*pv.*v./py
%     model      - struct with fields ind1_inv, ind2_inv, eta_x, epsilon_x,
%                  rho_x, zeta_x, ind1_mat, ind2_mat, eta_m, epsilon_m,
%                  rho_m, zeta_m, gamma_y, beta, delta, alpha, ind1_c,
%                  ind2_c, zeta_c, Theta_c, epsilon_c, sigma, gamma_l
%                  (index fields are iterated with "for i = ...", so they
%                  are presumably row vectors -- TODO confirm with caller;
%                  z, psi, alpha are presumably 5x1 -- TODO confirm)
%
%   Outputs
%     equilibrium - struct holding the quantities from the last iteration
%                   (py, pv, px, pm, y, v, x, xij, m, mij, c, e, psi, z,
%                   varphi, rho, lambdac, l, ck, cn, w)
%     retcode     - 1 if the final excess-demand norm satisfies the
%                   convergence criterion and no imaginary excess demand
%                   was detected; 0 otherwise (a warning is issued)

% Load model parameters
ind1_inv    = model.ind1_inv;
ind2_inv    = model.ind2_inv;
eta_x       = model.eta_x;
epsilon_x   = model.epsilon_x;
rho_x       = model.rho_x;
zeta_x      = model.zeta_x;
ind1_mat    = model.ind1_mat;
ind2_mat    = model.ind2_mat;
eta_m       = model.eta_m;
epsilon_m   = model.epsilon_m;
rho_m       = model.rho_m;
zeta_m      = model.zeta_m;
gamma_y     = model.gamma_y;
beta        = model.beta;
delta       = model.delta;
alpha       = model.alpha;
% s_c         = model.s_c;
ind1_c      = model.ind1_c;
ind2_c      = model.ind2_c;
zeta_c      = model.zeta_c;
Theta_c     = model.Theta_c;
epsilon_c   = model.epsilon_c;
sigma       = model.sigma;
gamma_l     = model.gamma_l;


% -- Convergence Parameters -- %

lambda_start      = 0.0001;             % tatonnement updating parameter
lambda_max        = 1;
lambda_min        = 1e-20;
lambda            = ones(6,1)*lambda_start;

excess_demand_old = ones(6,1)*Inf;
largest_excess_demand_old = 1e+20;

crit        = 1e-5;              % convergence criteria

% -- Initial Guess -- %
eps_start       = 0;
py_lambda_start = py_lambda;
py_lambda       = 1*(py_lambda_start.*(ones(6,1) + eps_start));


% -- STEPS -- %
converge    = Inf;    
iter        = 0;
max_iter    = 200000;       % 200000

% Set when a complex-valued excess demand aborts the iteration, so that the
% return code reports failure regardless of the final norm.
imag_failure = false;

% Options for the inner fsolve call in STEP 16. They do not depend on the
% iteration, so they are built once instead of on every pass.
options = optimoptions('fsolve','Display','off','OptimalityTolerance',1e-10);

tic     % starts the stopwatch timer; no matching toc appears in this function
while converge > crit && iter < max_iter
    iter = iter + 1;

    py          = py_lambda(1:5);
    lambda_c    = py_lambda(6);
    
    % STEP 1: solve p^x = P^X(p^y)
    px = NaN(5,1);
    for j = 1:5
        temp_sum = 0;
        for i = ind1_inv
            temp_sum = temp_sum + eta_x(i,j)^epsilon_x(j)*py(i)^(1-epsilon_x(j));
        end
        temp_sum = temp_sum^(rho_x(j)/(1-epsilon_x(j)));
    
        temp_prod = 1;
        for i = ind2_inv
            temp_prod = temp_prod*py(i)^(zeta_x(i,j)*(1-rho_x(j)));
        end
        px(j) = temp_sum*temp_prod;
    end

    % STEP 2: solve p^m = P^M(p^y)
    pm = NaN(5,1);
    for j = 1:5
        temp_sum = 0;
        for i = ind1_mat
            temp_sum = temp_sum + eta_m(i,j)^epsilon_m(j)*py(i)^(1-epsilon_m(j));
        end
        temp_sum = temp_sum^(rho_m(j)/(1-epsilon_m(j)));
    
        temp_prod = 1;
        for i = ind2_mat
            temp_prod = temp_prod*py(i)^(zeta_m(i,j)*(1-rho_m(j)));
        end
        pm(j) = temp_sum*temp_prod;
    end


    % STEP 3: solve for p^v from p^y = P^Y(p^v,p^m) 
    pv   = NaN(5,1);
    for j = 1:5
        pv(j) = (py(j)/pm(j)^(1-gamma_y(j)))^(1/gamma_y(j));
    end
    pv = real(pv);      % discard any imaginary part produced by the fractional power

    % STEP 4: solve for u from Euler equation
    u = px.*(1/beta - 1 + delta);

    % STEP 5: solve for capital-labor ratio from FOCs for capital
    klratio = (u./pv./z.*(alpha./(1-alpha)).^(alpha-1)).^(1./(alpha-1));
    

    % STEP 6: Solve for wage rates from FOCs for labor
    % NOTE(review): with v as defined in STEP 10, the labor FOC
    % w = pv*(1-alpha)*v/l would give the factor ((1-alpha)./alpha).^alpha
    % rather than (alpha./(1-alpha)).^alpha -- confirm against the model
    % derivation before changing; the formula is left as in the original.
    w = pv.*z.*klratio.^alpha.*(alpha./(1-alpha)).^(alpha);
    
    % STEP 6': Solve for labor
    l = (lambda_c*w./varphi).^(1/gamma_l);


    % STEP 7: capital-labor ratios, factor prices, and value added prices
    % should satisfy FOC
    % this step is a check
    
    % STEP 8: conditional on k/l and l, obtain k
    k = klratio.*l;
    

    % STEP 9: solve for x from steady state capital accumulation
    x = delta.*k;
    

    % STEP 10: solve v = z*V(k,l)
    v = z.*klratio.^alpha.*l./(alpha.^alpha)./(1-alpha).^(1-alpha);
    

    % STEP 11: Solve for y from v =dP^Y(p^v,p^m)/dp^v * y
    y = NaN(5,1);
    for j = 1:5
        y(j) = 1/gamma_y(j)*pv(j)^(1-gamma_y(j))*pm(j)^(gamma_y(j)-1)*v(j);
    end
    

    % STEP 12: solve for m = dP^Y(p^v,p^m)/dp^m * y
    m = NaN(5,1);
    for j = 1:5
        m(j) = (1-gamma_y(j))*pv(j)^gamma_y(j)*pm(j)^-gamma_y(j)*y(j);
    end

    %  solve for m{i,j} = dP^M(p^y)/dp^y *m(j)
    %  Entries of mij outside ind1_mat and ind2_mat stay NaN.
    mij = NaN(5,5);
    for j = 1:5
        pmkj = sum(eta_m(ind1_mat,j).^epsilon_m(j).*py(ind1_mat).^(1-epsilon_m(j))).^(1/(1-epsilon_m(j)));
        pmnj = prod(py(ind2_mat).^zeta_m(ind2_mat,j));
        for i = ind1_mat
            mij(i,j) = rho_m(j)*pmnj^(1-rho_m(j))*pmkj^(rho_m(j)-1)*sum(eta_m(ind1_mat,j).^epsilon_m(j).*py(ind1_mat).^(1-epsilon_m(j)))^(epsilon_m(j)/(1-epsilon_m(j)))*eta_m(i,j)^epsilon_m(j)*py(i)^(-epsilon_m(j))*m(j);
        end
        for i = ind2_mat
            mij(i,j) = (1-rho_m(j))*pmkj^rho_m(j)*pmnj^(-rho_m(j))*zeta_m(i,j)*prod(py(ind2_mat).^zeta_m(ind2_mat,j))/py(i)*m(j);
        end
    end
    
    
    % STEP 13: Solve x(i,j) = dP(j)^x(p^y)/dp^y(i) *x(j)
    %  Entries of xij outside ind1_inv and ind2_inv stay NaN.
    xij  = NaN(5,5);
    for j = 1:5
        pxkj = sum(eta_x(ind1_inv,j).^epsilon_x(j).*py(ind1_inv).^(1-epsilon_x(j))).^(1/(1-epsilon_x(j)));
        pxnj = prod(py(ind2_inv).^zeta_x(ind2_inv,j));
        for i = ind1_inv
            xij(i,j) = rho_x(j)*pxnj^(1-rho_x(j))*pxkj^(rho_x(j)-1)*sum(eta_x(ind1_inv,j).^epsilon_x(j).*py(ind1_inv).^(1-epsilon_x(j)))^(epsilon_x(j)/(1-epsilon_x(j)))*eta_x(i,j)^epsilon_x(j)*py(i)^(-epsilon_x(j))*x(j);
        end
        for i = ind2_inv
            xij(i,j) = (1-rho_x(j))*pxkj^rho_x(j)*pxnj^(-rho_x(j))*zeta_x(i,j)*prod(py(ind2_inv).^zeta_x(ind2_inv,j))/py(i)*x(j);
        end
    end
    
    
    % STEP 14: solve for consumption from resource constraints
    cs = y - sum(mij,2) - sum(xij,2) - psi.*pv.*v./py;
    cs = max(cs,1e-10);     % floor supply of consumption goods at a small positive number
    
    % STEP 15: construct aggregate nominal expenditures
    e = sum(py.*cs);
    
    % STEP 16: solve for consumption (scalar cn, started from rho*e)
    cn = fsolve(@(cn) step16solver(cn,Theta_c,py,ind2_c,epsilon_c,sigma,e,rho),rho*e,options);


    % STEP 17: solve for real consumption index
    % en = sum(Theta_c(ind2_c).*py(ind2_c).^(1-sigma).*cn.^(epsilon_c(ind2_c).*(1-sigma)))^(1/(1-sigma));
    
    % STEP 18: 
    num = sum(epsilon_c(ind2_c).*Theta_c(ind2_c).*py(ind2_c).^(1-sigma).*cn.^(epsilon_c(ind2_c)*(1-sigma)));
    den = sum(Theta_c(ind2_c).*py(ind2_c).^(1-sigma).*cn.^(epsilon_c(ind2_c)*(1-sigma)));
    eta_epsilon = num/den;

    % STEP 19: ek
    ek = (rho*eta_epsilon)/(rho*eta_epsilon+(1-rho))*e;

    % STEP 20: ck
    ck = ek/prod(py(ind1_c).^zeta_c(ind1_c));

    % STEP 21: obtain demand for goods c^d(i) = dE(p^y,c)/dpy(i)
    cd = NaN(5,1);
    for j = ind1_c
        cd(j) = zeta_c(j)*prod(py(ind1_c).^zeta_c(ind1_c))/py(j)*ck;
    end
    for j = ind2_c
        cd(j) = sum(Theta_c(ind2_c).*py(ind2_c).^(1-sigma).*cn.^(epsilon_c(ind2_c)*(1-sigma)))^(sigma/(1-sigma))*Theta_c(j)*cn^(epsilon_c(j)*(1-sigma))*py(j)^-sigma;
    end

    % Five goods-market gaps followed by the equation that pins down lambda_c
    excess_demand = [cd - cs;
                    rho/prod(py(ind1_c).^zeta_c(ind1_c))*(cn/ck)^(1-rho) - lambda_c];

    if norm(imag(excess_demand)) > 0
        % Complex excess demand: dump the workspace to temp_error.mat, wait
        % for a key press, then force the loop to stop after this pass.
        warning('Imaginary Excess Demand')
        save temp_error;
        pause;
        iter = max_iter;
        imag_failure = true;
    end

    converge = norm(excess_demand);

    if disp_code == 1
        disp(['Convergence norm for excess demand = ' num2str(converge)])
    end
    
    largest_excess_demand = max(abs(excess_demand));
    changeInLargestExcessDemand = largest_excess_demand - largest_excess_demand_old;
    acceptGuess = changeInLargestExcessDemand < 0;

    if iter > max_iter/2
        % Second half of the iteration budget: fixed (scalar) step size,
        % and the "old" excess-demand records are no longer updated.
        lambda = lambda_start;
    else
        if acceptGuess
            % Largest gap shrank: adapt each step size individually,
            % shrinking it where that equation's gap grew and enlarging it
            % where the gap shrank.
            for j = 1:6
                if abs(excess_demand(j)) > abs(excess_demand_old(j))
                    lambda(j) = max(lambda(j) - .0001, lambda_min) ;
                end
                if abs(excess_demand(j)) < abs(excess_demand_old(j))
                    lambda(j) = min(lambda(j) + .0001, lambda_max);
                end
            end
        else
            % Largest gap did not shrink: reset all step sizes
            lambda = ones(6,1)*lambda_start;
        end
    
        excess_demand_old = excess_demand;
        largest_excess_demand_old = largest_excess_demand;
    end

    py_lambda = max(py_lambda + lambda.*excess_demand,1e-10);
    

end

% eps_start
% iter

% py_norm = py./py(4);
% py_start_py = [py_start py py_norm]

% max_excess_demand = max(excess_demand)


% Report success from the convergence test itself rather than from the
% iteration counter. "converge <= crit" is false for NaN, so a NaN norm
% (which also terminates the while loop) is reported as a failure, and a
% run that converges exactly on the last allowed iteration is a success.
converged = ~imag_failure && converge <= crit;

if ~converged
    warning(['Convergence norm = ' num2str(converge) ', iterations = ' num2str(iter)]);
    retcode = 0;
else
    retcode = 1;
end

equilibrium.py      = py;
equilibrium.pv      = pv;
equilibrium.px      = px;
equilibrium.pm      = pm;
equilibrium.y       = y;
equilibrium.v       = v;
equilibrium.x       = x;
equilibrium.xij     = xij;
equilibrium.m       = m;
equilibrium.mij     = mij;
equilibrium.c       = cs;
equilibrium.e       = e;
equilibrium.psi     = psi;
equilibrium.z       = z;
equilibrium.varphi  = varphi;
equilibrium.rho     = rho;
equilibrium.lambdac = lambda_c;
equilibrium.l       = l;
equilibrium.ck      = ck;
equilibrium.cn      = cn;
equilibrium.w       = w;

end


function resid = step16solver(cn,omega,py,ind2,epsilon,sigma,e,rho)
% STEP16SOLVER  Residual of the STEP 16 equation that is solved for cn.
%
%   resid = step16solver(cn,omega,py,ind2,epsilon,sigma,e,rho) returns
%       (rho*A + (1-rho)*B)/(1-rho) * B^(sigma/(1-sigma)) / e - 1
%   where, summing over the goods indexed by ind2,
%       B = sum(omega.*py.^(1-sigma).*cn.^(epsilon*(1-sigma)))
%       A = sum(epsilon.*omega.*py.^(1-sigma).*cn.^(epsilon*(1-sigma)))
%   A root of resid in cn is what the caller's fsolve looks for.

    % Power terms shared by both sums
    pricePow = py(ind2).^(1-sigma);
    cnPow    = cn.^(epsilon(ind2)*(1-sigma));

    % Unweighted (B) and epsilon-weighted (A) aggregates
    plainSum    = sum(omega(ind2).*pricePow.*cnPow);
    weightedSum = sum(epsilon(ind2).*omega(ind2).*pricePow.*cnPow);

    mixTerm   = rho*weightedSum + (1-rho)*plainSum;
    indexTerm = (plainSum)^(sigma/(1-sigma));

    resid = mixTerm/(1-rho)*indexTerm/e - 1;
end